{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-01-20T12:14:55.514295Z",
     "start_time": "2025-01-20T12:14:55.507655Z"
    }
   },
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n",
    "\n",
    "seed = 42\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.0\n",
      "torch 2.5.1+cpu\n",
      "cpu\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:14:55.642255Z",
     "start_time": "2025-01-20T12:14:55.609404Z"
    }
   },
   "cell_type": "code",
   "source": "! tree archive",
   "id": "f3ae6eece2066215",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "卷 新加卷 的文件夹 PATH 列表\n",
      "卷序列号为 0000003C FE1F:B69C\n",
      "E:\\01_PYTHON_CODE_2025\\DEEP_LEARNING\\CHAPTER_4_TORCH\\ARCHIVE\n",
      "├─training\n",
      "│  ├─n0\n",
      "│  ├─n1\n",
      "│  ├─n2\n",
      "│  ├─n3\n",
      "│  ├─n4\n",
      "│  ├─n5\n",
      "│  ├─n6\n",
      "│  ├─n7\n",
      "│  ├─n8\n",
      "│  └─n9\n",
      "└─validation\n",
      "    ├─n0\n",
      "    ├─n1\n",
      "    ├─n2\n",
      "    ├─n3\n",
      "    ├─n4\n",
      "    ├─n5\n",
      "    ├─n6\n",
      "    ├─n7\n",
      "    ├─n8\n",
      "    └─n9\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 准备数据",
   "id": "4bbbf89672c591d2"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:14:55.657979Z",
     "start_time": "2025-01-20T12:14:55.645261Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor, Resize, Compose, ConvertImageDtype, Normalize\n",
    "from pathlib import Path\n",
    "\n",
    "DATA_DIR = Path(\"./archive/\")\n",
    "\n",
    "\n",
    "class MonkeyDataset(datasets.ImageFolder):\n",
    "    def __init__(self, mode, transform=None):\n",
    "        if mode == \"train\":\n",
    "            # 使用 / 操作符拼接路径，代码更简洁易读。\n",
    "            root = DATA_DIR / \"training\"\n",
    "        elif mode == \"val\":\n",
    "            root = DATA_DIR / \"validation\"\n",
    "        else:\n",
    "            raise ValueError(\"Invalid mode\")\n",
    "        # 调用父类init方法\n",
    "        super().__init__(root, transform)\n",
    "        # self.imgs = self.samples # self.samples里边是图片路径及标签 [(path, label), (path, label),...]\n",
    "        self.targets = [s[1] for s in self.samples]  # 取出标签列表\n",
    "\n",
    "\n",
    "# 预先设定的图片尺寸\n",
    "img_h, img_w = 128, 128\n",
    "transform = Compose([\n",
    "    Resize((img_h, img_w)),  # 图片缩放\n",
    "    ToTensor(),\n",
    "    # 预先统计的\n",
    "    Normalize([0.4363, 0.4328, 0.3291], [0.2085, 0.2032, 0.1988]),\n",
    "    ConvertImageDtype(torch.float),\n",
    "])\n",
    "\n",
    "train_ds = MonkeyDataset(\"train\", transform=transform)\n",
    "val_ds = MonkeyDataset(\"val\", transform=transform)\n",
    "\n",
    "print(len(train_ds), len(val_ds))  # 训练集和验证集的样本数 1097 272"
   ],
   "id": "ab4648609e29001a",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1097 272\n"
     ]
    }
   ],
   "execution_count": 18
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:14:55.673013Z",
     "start_time": "2025-01-20T12:14:55.669251Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 数据类型\n",
    "train_ds.classes"
   ],
   "id": "57cddcfcb827b9c4",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['n0', 'n1', 'n2', 'n3', 'n4', 'n5', 'n6', 'n7', 'n8', 'n9']"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 19
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:14:55.693377Z",
     "start_time": "2025-01-20T12:14:55.688039Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# class_to_idx 是一个字典，键是类别名称（字符串），值是对应的索引（整数）\n",
    "# 在训练模型时，通常需要将类别名称（如 'cat'、'dog'）转换为整数索引（如 0、1）。\n",
    "train_ds.class_to_idx"
   ],
   "id": "3c46b191362e3794",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n0': 0,\n",
       " 'n1': 1,\n",
       " 'n2': 2,\n",
       " 'n3': 3,\n",
       " 'n4': 4,\n",
       " 'n5': 5,\n",
       " 'n6': 6,\n",
       " 'n7': 7,\n",
       " 'n8': 8,\n",
       " 'n9': 9}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 20
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:14:55.710478Z",
     "start_time": "2025-01-20T12:14:55.706397Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 图片路径 及 标签\n",
    "i = 0\n",
    "for fpath, label in train_ds.samples:\n",
    "    print(fpath, label)\n",
    "    i += 1\n",
    "    if i == 10:\n",
    "        break"
   ],
   "id": "d68a4b70e23944ee",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "archive\\training\\n0\\n0018.jpg 0\n",
      "archive\\training\\n0\\n0019.jpg 0\n",
      "archive\\training\\n0\\n0020.jpg 0\n",
      "archive\\training\\n0\\n0021.jpg 0\n",
      "archive\\training\\n0\\n0022.jpg 0\n",
      "archive\\training\\n0\\n0023.jpg 0\n",
      "archive\\training\\n0\\n0024.jpg 0\n",
      "archive\\training\\n0\\n0025.jpg 0\n",
      "archive\\training\\n0\\n0026.jpg 0\n",
      "archive\\training\\n0\\n0027.jpg 0\n"
     ]
    }
   ],
   "execution_count": 21
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.802030Z",
     "start_time": "2025-01-20T12:14:55.754504Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 因为有3通道，所有有3个mean和std\n",
    "def calc_mean_std(ds):\n",
    "    mean = 0.\n",
    "    std = 0.\n",
    "    for img, _ in ds:\n",
    "        mean += img[:, :, :].mean(dim=(1, 2))\n",
    "        std += img[:, :, :].std(dim=(1, 2))\n",
    "    mean /= len(ds)\n",
    "    std /= len(ds)\n",
    "    return mean, std\n",
    "\n",
    "\n",
    "print(calc_mean_std(train_ds))"
   ],
   "id": "1dad0ae672d4fee9",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([ 1.5299e-04,  3.6267e-05, -6.5391e-07]), tensor([0.9999, 0.9999, 1.0002]))\n"
     ]
    }
   ],
   "execution_count": 22
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.808303Z",
     "start_time": "2025-01-20T12:15:11.803032Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch.nn as nn\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "batch_size = 64\n",
    "# 从数据集到dataloader，num_workers参数不能加，否则会报错\n",
    "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)"
   ],
   "id": "bf91b2881594a04e",
   "outputs": [],
   "execution_count": 23
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "",
   "id": "9baa908563f509d2"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 定义模型",
   "id": "ebf38d30bf72643b"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.873832Z",
     "start_time": "2025-01-20T12:15:11.809304Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self, num_classes=10, activation=\"relu\"):\n",
    "        super(CNN, self).__init__()\n",
    "        self.activation = F.relu if activation == \"relu\" else F.selu\n",
    "        self.flattener = nn.Flatten()  # 展平层\n",
    "        # 卷积层 \n",
    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=\"same\")\n",
    "        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=\"same\")\n",
    "        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=\"same\")\n",
    "        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=\"same\")\n",
    "        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=\"same\")\n",
    "        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=\"same\")\n",
    "        # 池化层\n",
    "        # nn.MaxPool2d:这是 PyTorch 中的最大池化层，用于对输入数据进行降采样。\n",
    "        # kernel_size=2：池化核的大小，通常是一个正方形。\n",
    "        self.pool = nn.MaxPool2d(2, 2)  #池化核大小为2（2*2），步长为2\n",
    "        # 全连接层\n",
    "        # (32,128,3,3)->(32,128*3*3)\n",
    "        self.fc1 = nn.Linear(128 * 16 * 16, 128)\n",
    "        self.fc2 = nn.Linear(128, num_classes)\n",
    "\n",
    "        self.init_weights()  # 初始化权重\n",
    "\n",
    "    def init_weights(self):  # 初始化权重\n",
    "        \"\"\"使用 xavier 均匀分布来初始化全连接层、卷积层的权重 W\"\"\"\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, (nn.Linear, nn.Conv2d)):\n",
    "                nn.init.xavier_uniform_(m.weight)\n",
    "                nn.init.zeros_(m.bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        act = self.activation\n",
    "        #          conv1          conv2         pool  \n",
    "        # 32*3*128*128 -> 32*32*128*128-> 32*32*128*128->32*32*64*64\n",
    "        x = self.pool(act(self.conv2(act(self.conv1(x)))))\n",
    "        # print(x.shape)\n",
    "        #           conv3          conv4         pool  \n",
    "        # 32*32*64*64 -> 32*64*64*64 -> 32*64*64*64 -> 32*64*32*32\n",
    "        x = self.pool(act(self.conv4(act(self.conv3(x)))))\n",
    "        # print(x.shape)\n",
    "        #           conv5          conv6         pool  \n",
    "        # 32*64*32*32 -> 32*128*32*32 -> 32*128*32*32 -> 32*128*16*16\n",
    "        x = self.pool(act(self.conv6(act(self.conv5(x)))))\n",
    "        # print(x.shape)\n",
    "\n",
    "        # 32*128*3*3 ->32*(128*3*3)\n",
    "        x = self.flattener(x)  #展平\n",
    "        # 32*(128*3*3)->32*128\n",
    "        x = act(self.fc1(x))\n",
    "        # 32*128->32*10\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "for idx, (key, value) in enumerate(CNN().named_parameters()):\n",
    "    print(f\"{key}\\tparamerters num: {np.prod(value.shape)}\")  # 打印模型的参数信息"
   ],
   "id": "88c29638f533a700",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.weight\tparamerters num: 864\n",
      "conv1.bias\tparamerters num: 32\n",
      "conv2.weight\tparamerters num: 9216\n",
      "conv2.bias\tparamerters num: 32\n",
      "conv3.weight\tparamerters num: 18432\n",
      "conv3.bias\tparamerters num: 64\n",
      "conv4.weight\tparamerters num: 36864\n",
      "conv4.bias\tparamerters num: 64\n",
      "conv5.weight\tparamerters num: 73728\n",
      "conv5.bias\tparamerters num: 128\n",
      "conv6.weight\tparamerters num: 147456\n",
      "conv6.bias\tparamerters num: 128\n",
      "fc1.weight\tparamerters num: 4194304\n",
      "fc1.bias\tparamerters num: 128\n",
      "fc2.weight\tparamerters num: 1280\n",
      "fc2.bias\tparamerters num: 10\n"
     ]
    }
   ],
   "execution_count": 24
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.926535Z",
     "start_time": "2025-01-20T12:15:11.875833Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 计算模型总参数量\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "\n",
    "count_parameters(CNN())"
   ],
   "id": "a27564a6576cf8e5",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4482730"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 25
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 训练模型",
   "id": "10cfa4a7bac92bbc"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.932806Z",
     "start_time": "2025-01-20T12:15:11.927536Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "\n",
    "@torch.no_grad()  # 装饰器，禁止梯度计算\n",
    "def evaluate(model, data_loader, loss_fct):\n",
    "    loss_list = []\n",
    "    pred_list = []\n",
    "    label_list = []\n",
    "    for datas, labels in data_loader:\n",
    "        datas = datas.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # 前向传播\n",
    "        logits = model(datas)\n",
    "        loss = loss_fct(logits, labels)  # 验证集损失\n",
    "        # tensor.item() 获取tensor的数值，loss是只有一个元素的tensor\n",
    "        loss_list.append(loss.item())\n",
    "\n",
    "        # 预测\n",
    "        preds = logits.argmax(axis=-1)  # 预测类别\n",
    "        pred_list.extend(preds.cpu().numpy().tolist())  # tensor转numpy，再转list\n",
    "        label_list.extend(labels.cpu().numpy().tolist())\n",
    "\n",
    "    acc = accuracy_score(label_list, pred_list)  # 计算准确率\n",
    "    return np.mean(loss_list), acc  # # 返回验证集平均损失和准确率"
   ],
   "id": "c20254bb6b07dcc2",
   "outputs": [],
   "execution_count": 26
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.939272Z",
     "start_time": "2025-01-20T12:15:11.933808Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    def __init__(self, save_dir, save_step=500, save_best_only=True):\n",
    "        self.save_dir = save_dir  # 保存路径\n",
    "        self.save_step = save_step  # 保存步数\n",
    "        self.save_best_only = save_best_only  # 是否只保存最好的模型\n",
    "        self.best_metric = -1  # 最好的指标，指标不可能为负数，所以初始化为-1\n",
    "        # 创建保存路径\n",
    "        if not os.path.exists(self.save_dir):  # 如果不存在保存路径，则创建\n",
    "            os.makedirs(self.save_dir)\n",
    "\n",
    "    # 对象被调用时：当你将对象像函数一样调用时，Python 会自动调用 __call__ 方法。\n",
    "    # state_dict() 返回模型参数的字典，包括模型参数和优化器参数\n",
    "    # metric 是指标，可以是验证集的准确率，也可以是其他指标\n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        if step % self.save_step > 0:\n",
    "            return  # 不是保存步数，则直接返回\n",
    "\n",
    "        if self.save_best_only:\n",
    "            assert metric is not None  # 必须传入metric\n",
    "            if metric >= self.best_metric:  # 如果当前指标大于最好的指标\n",
    "                # save checkpoint\n",
    "                # 保存最好的模型，覆盖之前的模型，不保存step，只保存state_dict，即模型参数，不保存优化器参数\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"04_monkey_best.ckpt\"))\n",
    "                self.best_metric = metric  # 更新最好的指标\n",
    "        else:\n",
    "            # 保存模型\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))\n",
    "            # 保存每个step的模型，不覆盖之前的模型，保存step，保存state_dict，即模型参数，不保存优化器参数"
   ],
   "id": "2020775d13810cd8",
   "outputs": [],
   "execution_count": 27
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.945586Z",
     "start_time": "2025-01-20T12:15:11.940273Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class EarlyStopCallback:\n",
    "    def __init__(self, patience=5, min_delta=0.01):\n",
    "        self.patience = patience  # 多少个step没有提升就停止训练\n",
    "        self.min_delta = min_delta  # 最小的提升幅度\n",
    "        self.best_metric = -1  # 记录的最好的指标\n",
    "        self.counter = 0  # 计数器，记录连续多少个step没有提升\n",
    "\n",
    "    def __call__(self, metric):\n",
    "        if metric >= self.best_metric + self.min_delta:  # 如果指标提升了\n",
    "            self.best_metric = metric  # 更新最好的指标\n",
    "            self.counter = 0  # 计数器清零\n",
    "        else:\n",
    "            self.counter += 1  # 计数器加一\n",
    "\n",
    "    @property  # 使用@property装饰器，使得 对象.early_stop可以调用，不需要()\n",
    "    def early_stop(self):\n",
    "        # 如果计数器大于等于patience，则返回True，停止训练\n",
    "        return self.counter >= self.patience"
   ],
   "id": "8ada7ce78dca1e58",
   "outputs": [],
   "execution_count": 28
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.954421Z",
     "start_time": "2025-01-20T12:15:11.946588Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def training(model,\n",
    "             train_loader,\n",
    "             val_loader,\n",
    "             epoch,\n",
    "             loss_fct,\n",
    "             optimizer,\n",
    "             save_ckpt_callback=None,\n",
    "             early_stop_callback=None,\n",
    "             eval_step=500,\n",
    "             ):\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "        \"val\": []\n",
    "    }\n",
    "\n",
    "    global_step = 0  # 全局步数\n",
    "    model.train()  # 训练模式\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                # 前向传播\n",
    "                logits = model(datas)\n",
    "                loss = loss_fct(logits, labels)  # 训练集损失\n",
    "                preds = logits.argmax(axis=-1)  # 预测类别\n",
    "\n",
    "                # 反向传播\n",
    "                optimizer.zero_grad()  # 梯度清零\n",
    "                loss.backward()  # 反向传播\n",
    "                optimizer.step()  # 优化器更新参数\n",
    "\n",
    "                # 计算准确率\n",
    "                acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())\n",
    "                loss = loss.cpu().item()\n",
    "\n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss,\n",
    "                    \"acc\": acc,\n",
    "                    \"step\": global_step\n",
    "                })\n",
    "\n",
    "                # 评估\n",
    "                if global_step % eval_step == 0:\n",
    "                    model.eval()  # 评估模式\n",
    "                    # 验证集损失和准确率\n",
    "                    val_loss, val_acc = evaluate(model, val_loader, loss_fct)\n",
    "                    record_dict[\"val\"].append({\n",
    "                        \"loss\": val_loss,\n",
    "                        \"acc\": val_acc,\n",
    "                        \"step\": global_step\n",
    "                    })\n",
    "                    model.train()  # 训练模式\n",
    "\n",
    "                    # 2. 保存模型权重 save model checkpoint\n",
    "                    if save_ckpt_callback is not None:\n",
    "                        # model.state_dict() 返回模型参数的字典，包括模型参数和优化器参数\n",
    "                        save_ckpt_callback(global_step, model.state_dict(), val_acc)\n",
    "                        # 保存最好的模型，覆盖之前的模型，保存step，保存state_dict,通过metric判断是否保存最好的模型\n",
    "\n",
    "                    # 3. 早停 early stopping\n",
    "                    if early_stop_callback is not None:\n",
    "                        # 验证集准确率不再提升，则停止训练\n",
    "                        early_stop_callback(val_acc)\n",
    "                        # 验证集准确率不再提升，则停止训练\n",
    "                        if early_stop_callback.early_stop:\n",
    "                            print(f\"Early stop at epoch {epoch_id} / global_step {global_step}\")\n",
    "                            return record_dict  # 早停，返回记录字典 record_dict\n",
    "\n",
    "                # 更新进度条和全局步数\n",
    "                pbar.update(1)  # 更新进度条\n",
    "                global_step += 1  # 全局步数加一\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "\n",
    "    return record_dict  # 训练结束，返回记录字典 record_dict"
   ],
   "id": "76521918c12cc25b",
   "outputs": [],
   "execution_count": 29
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:15:11.996415Z",
     "start_time": "2025-01-20T12:15:11.955423Z"
    }
   },
   "cell_type": "code",
   "source": [
    "epoch = 100\n",
    "\n",
    "activation = \"selu\"  # 激活函数\n",
    "model = CNN(num_classes=len(train_ds.classes), activation=activation)  # 定义模型\n",
    "\n",
    "# 1. 定义损失函数 采用MSE损失\n",
    "loss_fct = nn.CrossEntropyLoss()\n",
    "\n",
    "# 2. 定义优化器 采用SGD优化器\n",
    "# eps=1e-7 防止除0错误\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001, eps=1e-7)\n",
    "\n",
    "# 3.save model checkpoint\n",
    "if not os.path.exists(\"checkpoints\"):\n",
    "    os.makedirs(\"checkpoints\")\n",
    "save_ckpt_callback = SaveCheckpointsCallback(save_dir=\"checkpoints\", save_step=len(train_loader), save_best_only=True)\n",
    "\n",
    "# 4. early stopping\n",
    "early_stop_callback = EarlyStopCallback(patience=10, min_delta=0.01)"
   ],
   "id": "4f75739b23ef50ef",
   "outputs": [],
   "execution_count": 30
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:16:02.782716Z",
     "start_time": "2025-01-20T12:15:11.998418Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model = model.to(device)  # 将模型移到GPU上\n",
    "\n",
    "# 训练过程\n",
    "record_dict = training(\n",
    "    model,\n",
    "    train_loader,\n",
    "    val_loader,\n",
    "    epoch,\n",
    "    loss_fct,\n",
    "    optimizer,\n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    "    early_stop_callback=early_stop_callback,\n",
    "    eval_step=len(train_loader)\n",
    ")"
   ],
   "id": "1a65ef232d245b75",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/1800 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "0f9c59d0b0854072991979032a82f289"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[31], line 4\u001B[0m\n\u001B[0;32m      1\u001B[0m model \u001B[38;5;241m=\u001B[39m model\u001B[38;5;241m.\u001B[39mto(device)  \u001B[38;5;66;03m# 将模型移到GPU上\u001B[39;00m\n\u001B[0;32m      3\u001B[0m \u001B[38;5;66;03m# 训练过程\u001B[39;00m\n\u001B[1;32m----> 4\u001B[0m record_dict \u001B[38;5;241m=\u001B[39m \u001B[43mtraining\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m      5\u001B[0m \u001B[43m    \u001B[49m\u001B[43mmodel\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      6\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtrain_loader\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      7\u001B[0m \u001B[43m    \u001B[49m\u001B[43mval_loader\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      8\u001B[0m \u001B[43m    \u001B[49m\u001B[43mepoch\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m      9\u001B[0m \u001B[43m    \u001B[49m\u001B[43mloss_fct\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     10\u001B[0m \u001B[43m    \u001B[49m\u001B[43moptimizer\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     11\u001B[0m \u001B[43m    \u001B[49m\u001B[43msave_ckpt_callback\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msave_ckpt_callback\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     12\u001B[0m \u001B[43m    \u001B[49m\u001B[43mearly_stop_callback\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mearly_stop_callback\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m     13\u001B[0m \u001B[43m    \u001B[49m\u001B[43meval_step\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mlen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mtrain_loader\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     14\u001B[0m \u001B[43m)\u001B[49m\n",
      "Cell \u001B[1;32mIn[29], line 48\u001B[0m, in \u001B[0;36mtraining\u001B[1;34m(model, train_loader, val_loader, epoch, loss_fct, optimizer, save_ckpt_callback, early_stop_callback, eval_step)\u001B[0m\n\u001B[0;32m     46\u001B[0m model\u001B[38;5;241m.\u001B[39meval()  \u001B[38;5;66;03m# 评估模式\u001B[39;00m\n\u001B[0;32m     47\u001B[0m \u001B[38;5;66;03m# 验证集损失和准确率\u001B[39;00m\n\u001B[1;32m---> 48\u001B[0m val_loss, val_acc \u001B[38;5;241m=\u001B[39m \u001B[43mevaluate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mval_loader\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mloss_fct\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     49\u001B[0m record_dict[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mval\u001B[39m\u001B[38;5;124m\"\u001B[39m]\u001B[38;5;241m.\u001B[39mappend({\n\u001B[0;32m     50\u001B[0m     \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m\"\u001B[39m: val_loss,\n\u001B[0;32m     51\u001B[0m     \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124macc\u001B[39m\u001B[38;5;124m\"\u001B[39m: val_acc,\n\u001B[0;32m     52\u001B[0m     \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mstep\u001B[39m\u001B[38;5;124m\"\u001B[39m: global_step\n\u001B[0;32m     53\u001B[0m })\n\u001B[0;32m     54\u001B[0m model\u001B[38;5;241m.\u001B[39mtrain()  \u001B[38;5;66;03m# 训练模式\u001B[39;00m\n",
      "File \u001B[1;32mC:\\Program Files\\Python312\\Lib\\site-packages\\torch\\utils\\_contextlib.py:116\u001B[0m, in \u001B[0;36mcontext_decorator.<locals>.decorate_context\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m    113\u001B[0m \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[0;32m    114\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mdecorate_context\u001B[39m(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m    115\u001B[0m     \u001B[38;5;28;01mwith\u001B[39;00m ctx_factory():\n\u001B[1;32m--> 116\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "Cell \u001B[1;32mIn[26], line 14\u001B[0m, in \u001B[0;36mevaluate\u001B[1;34m(model, data_loader, loss_fct)\u001B[0m\n\u001B[0;32m     11\u001B[0m labels \u001B[38;5;241m=\u001B[39m labels\u001B[38;5;241m.\u001B[39mto(device)\n\u001B[0;32m     13\u001B[0m \u001B[38;5;66;03m# 前向传播\u001B[39;00m\n\u001B[1;32m---> 14\u001B[0m logits \u001B[38;5;241m=\u001B[39m \u001B[43mmodel\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdatas\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     15\u001B[0m loss \u001B[38;5;241m=\u001B[39m loss_fct(logits, labels)  \u001B[38;5;66;03m# 验证集损失\u001B[39;00m\n\u001B[0;32m     16\u001B[0m \u001B[38;5;66;03m# tensor.item() 获取tensor的数值，loss是只有一个元素的tensor\u001B[39;00m\n",
      "File \u001B[1;32mC:\\Program Files\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m   1734\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)  \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m   1735\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1736\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32mC:\\Program Files\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m   1742\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m   1743\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m   1744\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m   1745\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m   1746\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1747\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m   1749\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m   1750\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
      "Cell \u001B[1;32mIn[24], line 35\u001B[0m, in \u001B[0;36mCNN.forward\u001B[1;34m(self, x)\u001B[0m\n\u001B[0;32m     32\u001B[0m act \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mactivation\n\u001B[0;32m     33\u001B[0m \u001B[38;5;66;03m#          conv1          conv2         pool  \u001B[39;00m\n\u001B[0;32m     34\u001B[0m \u001B[38;5;66;03m# 32*3*128*128 -> 32*32*128*128-> 32*32*128*128->32*32*64*64\u001B[39;00m\n\u001B[1;32m---> 35\u001B[0m x \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mpool(act(\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconv2\u001B[49m\u001B[43m(\u001B[49m\u001B[43mact\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconv1\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m))\n\u001B[0;32m     36\u001B[0m \u001B[38;5;66;03m# print(x.shape)\u001B[39;00m\n\u001B[0;32m     37\u001B[0m \u001B[38;5;66;03m#           conv3          conv4         pool  \u001B[39;00m\n\u001B[0;32m     38\u001B[0m \u001B[38;5;66;03m# 32*32*64*64 -> 32*64*64*64 -> 32*64*64*64 -> 32*64*32*32\u001B[39;00m\n\u001B[0;32m     39\u001B[0m x \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mpool(act(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mconv4(act(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mconv3(x)))))\n",
      "File \u001B[1;32mC:\\Program Files\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m   1734\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)  \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m   1735\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1736\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32mC:\\Program Files\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m   1742\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m   1743\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m   1744\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m   1745\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m   1746\u001B[0m         \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1747\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m   1749\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m   1750\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
      "File \u001B[1;32mC:\\Program Files\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:554\u001B[0m, in \u001B[0;36mConv2d.forward\u001B[1;34m(self, input)\u001B[0m\n\u001B[0;32m    553\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m: Tensor) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Tensor:\n\u001B[1;32m--> 554\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_conv_forward\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mweight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbias\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32mC:\\Program Files\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:549\u001B[0m, in \u001B[0;36mConv2d._conv_forward\u001B[1;34m(self, input, weight, bias)\u001B[0m\n\u001B[0;32m    537\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mpadding_mode \u001B[38;5;241m!=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mzeros\u001B[39m\u001B[38;5;124m\"\u001B[39m:\n\u001B[0;32m    538\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m F\u001B[38;5;241m.\u001B[39mconv2d(\n\u001B[0;32m    539\u001B[0m         F\u001B[38;5;241m.\u001B[39mpad(\n\u001B[0;32m    540\u001B[0m             \u001B[38;5;28minput\u001B[39m, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_reversed_padding_repeated_twice, mode\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mpadding_mode\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m    547\u001B[0m         \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mgroups,\n\u001B[0;32m    548\u001B[0m     )\n\u001B[1;32m--> 549\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mF\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconv2d\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m    550\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbias\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstride\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpadding\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdilation\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgroups\u001B[49m\n\u001B[0;32m    551\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "\u001B[1;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "execution_count": 31
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def _step_ticks(last_step, tick_step=5000):\n",
    "    \"\"\"Return x-tick positions from 0 to `last_step` (inclusive) and their \"k\" labels.\n",
    "\n",
    "    Multiples of 1000 are labelled as integers (5000 -> \"5k\"); other positions keep\n",
    "    their fractional part (500 -> \"0.5k\") so small `tick_step` values stay readable.\n",
    "    \"\"\"\n",
    "    positions = list(range(0, int(last_step) + 1, tick_step))\n",
    "    labels = [f\"{p // 1000}k\" if p % 1000 == 0 else f\"{p / 1000:g}k\" for p in positions]\n",
    "    return positions, labels\n",
    "\n",
    "\n",
    "def plot_record_curves(record_dict, sample_step=500, tick_step=5000):\n",
    "    \"\"\"Plot one train-vs-val curve per recorded metric (e.g. loss and accuracy).\n",
    "\n",
    "    Args:\n",
    "        record_dict: dict with \"train\" and \"val\" lists of records; every record\n",
    "            must contain a \"step\" key plus one key per metric.\n",
    "        sample_step: keep every `sample_step`-th training record, since training\n",
    "            metrics are recorded much more densely than validation metrics.\n",
    "        tick_step: spacing between x-axis ticks, in steps.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no training records to plot.\n",
    "    \"\"\"\n",
    "    # .set_index(\"step\") makes the step column the DataFrame index (the x axis)\n",
    "    train_full = pd.DataFrame(record_dict[\"train\"])\n",
    "    if train_full.empty:\n",
    "        raise ValueError(\"record_dict['train'] is empty; nothing to plot\")\n",
    "    train_full = train_full.set_index(\"step\")\n",
    "    train_df = train_full.iloc[::sample_step]\n",
    "    val_df = pd.DataFrame(record_dict[\"val\"])\n",
    "    if not val_df.empty:\n",
    "        val_df = val_df.set_index(\"step\")\n",
    "\n",
    "    # Use the true last step (not the last *sampled* one) so the axis covers all data,\n",
    "    # including validation points recorded after the last sampled training step.\n",
    "    last_step = train_full.index.max()\n",
    "    if not val_df.empty:\n",
    "        last_step = max(last_step, val_df.index.max())\n",
    "\n",
    "    print(train_df)\n",
    "    print(val_df)\n",
    "\n",
    "    # One subplot per metric column (originally documented as loss and accuracy).\n",
    "    fig_num = len(train_df.columns)\n",
    "    # squeeze=False always returns a 2-D array of Axes, so indexing also works\n",
    "    # when only a single metric was recorded (a bare Axes is not subscriptable).\n",
    "    fig, axs = plt.subplots(1, fig_num, figsize=(5 * fig_num, 5), squeeze=False)\n",
    "    tick_positions, tick_labels = _step_ticks(last_step, tick_step)\n",
    "    for ax, item in zip(axs[0], train_df.columns):\n",
    "        ax.plot(train_df.index, train_df[item], label=\"train:\" + item)\n",
    "        # Skip the val curve for metrics that were only logged during training.\n",
    "        if item in val_df.columns:\n",
    "            ax.plot(val_df.index, val_df[item], label=\"val:\" + item)\n",
    "        ax.grid()\n",
    "        ax.legend()\n",
    "        ax.set_xticks(tick_positions)\n",
    "        ax.set_xticklabels(tick_labels)\n",
    "        ax.set_xlabel(\"step\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# NOTE(review): record_dict must come from the training cell; the saved output of the\n",
    "# preceding run ends in a KeyboardInterrupt, so confirm record_dict exists and is current.\n",
    "plot_record_curves(record_dict)"
   ],
   "id": "3146342be350e25",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 评估",
   "id": "9754bfd8dbbfc3bf"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-20T12:16:02.785719Z",
     "start_time": "2025-01-20T12:16:02.784719Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model = CNN(activation)  # rebuild the architecture; trained weights are loaded below (as at deployment time)\n",
    "model = model.to(device)  # move the model to the selected device (GPU if available, otherwise CPU)\n",
    "\n",
    "# Load the best checkpoint saved during training.\n",
    "# torch.load: deserializes the saved object; here it is a state_dict passed to load_state_dict.\n",
    "# weights_only=True: restricts unpickling to tensors and plain containers so a tampered\n",
    "#   checkpoint cannot execute arbitrary code (a safety switch, not a \"weights vs. whole model\" switch).\n",
    "# map_location=device: remap the stored tensors onto the current device (GPU or CPU).\n",
    "BEST_CKPT = \"checkpoints/04_monkey_best.ckpt\"\n",
    "model.load_state_dict(torch.load(BEST_CKPT, weights_only=True, map_location=device))\n",
    "\n",
    "model.eval()  # evaluation mode: changes the behaviour of layers such as Dropout/BatchNorm, if present\n",
    "loss, acc = evaluate(model, val_loader, loss_fct)\n",
    "# The metrics are computed on val_loader, so label them as validation, not test.\n",
    "print(f\"Val loss: {loss:.4f}, Val acc: {acc:.4f}\")"
   ],
   "id": "be0bdf9c653f2ea1",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
